#!/usr/bin/env python3

import json
import os
import sys
from collections import OrderedDict
from typing import Optional, Tuple

import jinja2

class VrTopo:
    """ vrnetlab topo builder

        Expands a topology config into concrete links, assigning a numeric
        interface ID and an interface name to every link endpoint.

        Top-level config keys read by this class:
          routers    - mapping of router name to attributes, 'type' is mandatory
          p2p        - mapping of router name to a list of neighbors, each
                       written as "name" or "name:numeric"
          fullmeshes - mapping of mesh name to a list of router names
          hubs       - mapping of hub name to a list of endpoints, each written
                       as "name" or "name:numeric"

        Note that self.routers is the very same object as config['routers'];
        an 'interfaces' mapping is written into each router that gets an
        interface allocated.
    """

    # router types accepted for routers[<name>]['type']
    ROUTER_TYPES = ('dummy', 'xcon', 'bgp', 'xrv', 'xrv9k', 'vmx', 'sros',
                    'csr', 'c8000v', 'nxos', 'nxos9kv', 'ucpe-oneos', 'vqfx',
                    'vrp', 'veos', 'openwrt')

    def __init__(self, config, keep_order):
        """ Build the topology from config

            config is a mapping as described on the class. If keep_order is
            true, p2p / fullmesh entries are expanded in the order they appear
            in config instead of sorted order, and output() does not sort keys.

            Raises ValueError for a router without 'type', with an unknown
            type, or for a link endpoint naming an undefined router.
        """
        self.keep_order = keep_order
        self.links = []
        self.fullmeshes = {}
        self.hubs = {}
        self.routers = config.get('routers', {})

        # sanity checking - use a YANG model and pyang to validate input?
        for r, val in self.routers.items():
            if 'type' not in val:
                raise ValueError("'type' is not defined for router %s" % r)
            if val['type'] not in self.ROUTER_TYPES:
                raise ValueError("Unknown type %s for router %s" % (val['type'], r))

        # preserve order of entries if keep_order is set
        def maybe_sorted(d):
            return d if keep_order else sorted(d)

        links = []

        # expand p2p links
        def p2p():
            if 'p2p' not in config:
                return
            for router in maybe_sorted(config['p2p']):
                for neighbor in config['p2p'][router]:
                    name, numeric = self.router_name_interface(neighbor)
                    links.append({'left': {'router': router},
                                  'right': {'router': name, 'numeric': numeric}})

        # expand fullmesh into links
        def fullmesh():
            if 'fullmeshes' not in config:
                return
            for name in maybe_sorted(config['fullmeshes']):
                links.extend(self.expand_fullmesh(config['fullmeshes'][name]))

        # expand fullmeshes before p2p if keep_order is set and fullmeshes is defined before p2p
        keys = list(config)
        fullmesh_first = (keep_order
                          and 'fullmeshes' in config
                          and 'p2p' in config
                          and keys.index('fullmeshes') < keys.index('p2p'))
        if fullmesh_first:
            fullmesh()
            p2p()
        else:
            p2p()
            fullmesh()

        self.links = self.assign_interfaces(links)

        # index every link from both of its ends:
        # links_by_nodes[our_router][their_router] -> list of interface specs
        self.links_by_nodes = OrderedDict()
        for link in self.links:
            for a, b in (('left', 'right'), ('right', 'left')):
                ours, theirs = link[a], link[b]
                spec = {'our_interface': ours['interface'],
                        'their_interface': theirs['interface'],
                        'our_numeric': ours['numeric'],
                        'their_numeric': theirs['numeric']}
                by_peer = self.links_by_nodes.setdefault(ours['router'], OrderedDict())
                by_peer.setdefault(theirs['router'], []).append(spec)

        if 'hubs' in config:
            for hub in sorted(config['hubs']):
                self.hubs[hub] = []
                for router in config['hubs'][hub]:
                    name, interface = self.router_name_interface(router)
                    ep = {
                        'router': name,
                        'numeric': self.get_interface(name, interface)
                    }
                    # look up by the parsed router name; the raw endpoint may
                    # carry a ":numeric" suffix and is not a key in routers
                    ep['interface'] = self.intf_num_to_name(name, ep['numeric'])
                    self.hubs[hub].append(ep)

        # replace the allocation markers stored by get_interface() with the
        # interface names, so the output maps numeric ID -> name
        for router in sorted(self.routers):
            val = self.routers[router]
            if 'interfaces' in val:
                for num_id in val['interfaces']:
                    val['interfaces'][num_id] = self.intf_num_to_name(router, num_id)

    @staticmethod
    def router_name_interface(router: str) -> Tuple[str, Optional[int]]:
        """ Split "name:numeric" into (name, numeric)

            A plain "name" yields (name, None). Raises ValueError if the part
            after the first ':' is not an integer.
        """
        if ':' not in router:
            return router, None
        parts = router.split(':')
        return parts[0], int(parts[1])

    @staticmethod
    def expand_fullmesh(routers):
        """ Flatten a full-mesh into a list of links

            Links are considered bi-directional, so you will only see a link A->B
            and not a B->A. Duplicate router names are collapsed and a router is
            never linked to itself. Links are ordered by left then right name.
        """
        names = sorted(set(routers))
        return [{'left': {'router': a}, 'right': {'router': b}}
                for idx, a in enumerate(names)
                for b in names[idx + 1:]]


    def assign_interfaces(self, links):
        """ Assign numeric interfaces to links

            links is modified in place (and returned): each 'left' and 'right'
            endpoint gets 'numeric' and 'interface' set.
        """
        # first pass - reserve and validate static interface assignments, so
        # the dynamic allocation below cannot grab a statically requested ID
        for link in links:
            right = link['right']
            if right_numeric := right.get('numeric', None):
                self.get_interface(right['router'], right_numeric)

        # assign interfaces to links
        for link in links:
            for side in ('left', 'right'):
                end = link[side]
                end['numeric'] = self.get_interface(end['router'], end.get('numeric', None))
                end['interface'] = self.intf_num_to_name(end['router'], end['numeric'])

        return links


    def intf_num_to_name(self, router, interface):
        """ Map numeric ID to interface name

            The naming scheme depends on the type of the router. Returns None
            for types without a naming scheme here ('dummy', 'xcon').
        """
        rtype = self.routers[router]['type']
        if rtype in ('xrv', 'xrv9k'):
            return "GigabitEthernet0/0/0/%d" % (interface-1)
        if rtype in ('nxos9kv', 'nxos'):
            return "Ethernet1/%d" % (interface)
        if rtype == 'vmx':
            return "ge-0/0/%d" % (interface-1)
        if rtype == 'sros':
            # six ports per line card: numeric 1..6 -> 1/1/1..1/1/6, 7 -> 2/1/1, ...
            return "{}/1/{}".format(1+int((interface-1)/6), 1+(interface-1)%6)
        if rtype == 'bgp':
            return "tap%d" % (interface-1)
        if rtype in ('csr', 'c8000v'):
            return "GigabitEthernet%d" % (interface+1)
        if rtype == 'ucpe-oneos':
            return "eth%d" % (interface)
        if rtype == 'vqfx':
            return "xe-0/0/%d" % (interface-1)
        if rtype == 'vrp':
            return "GigabitEthernet4/0/%d" % (interface)
        if rtype == 'veos':
            return "Ethernet%d" % (interface)
        if rtype == 'openwrt':
            return "eth%d" % (interface)

        return None


    def get_interface(self, router, interface=None):
        """ Return next available interface, or allocate desired numeric interface

            Allocations are recorded in routers[router]['interfaces']. A
            statically requested ID n is stored as n -> n, a dynamically
            allocated one as n -> None. Requesting a static ID again is
            therefore accepted, while requesting an ID that was already handed
            out dynamically raises ValueError. A falsy interface (None or 0)
            means "allocate the lowest free ID, starting at 1".

            Raises ValueError if router is not defined in the config.
        """
        if router not in self.routers:
            raise ValueError("Router %s is not defined in config" % router)
        intfs = self.routers[router].setdefault('interfaces', {})
        if interface:
            if interface in intfs and intfs[interface] != interface:
                raise ValueError(f"Interface {interface} already assigned for {router}")
            intfs[interface] = interface
            return interface

        # lowest numeric ID, counting from 1, that is not taken yet
        i = 1
        while i in intfs:
            i += 1

        intfs[i] = None

        return i


    def output(self, output_format='json'):
        """ Output the resulting topology in given format

            output_format can only be json for now; anything else raises
            ValueError. Keys are sorted unless keep_order was set.
        """
        output = {
                'routers': self.routers,
                'links': self.links,
                'links_by_nodes': self.links_by_nodes,
                'hubs': self.hubs
            }

        if output_format != 'json':
            raise ValueError("Invalid output format")

        return json.dumps(output, sort_keys=(not self.keep_order), indent=4)


def run_command(cmd, dry_run=False):
    """ Run cmd (a list of argv strings) and capture its output

        Returns a (stdout, stderr) tuple of strings. With dry_run set the
        command is only printed, space-joined, and None is returned.
    """
    import subprocess

    if not dry_run:
        completed = subprocess.run(cmd,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE,
                                   universal_newlines=True)
        return completed.stdout, completed.stderr

    print(" ".join(cmd))
    return None


def _start_unless_running(name, cmd, dry_run):
    """ Start container `name` using cmd, unless docker reports it as running

        In dry-run mode docker is not queried at all; cmd is only printed.
    """
    if not dry_run:
        output, _ = run_command(["docker", "inspect", "--format", "{{.State.Running}}", name])
        if output.strip() == "true":
            output, _ = run_command(["docker", "inspect", "--format", "{{.State.Health.Status}}", name])
            print("Container already running. Health: %s" % output.strip())
            return
    run_command(cmd, dry_run)


def run_topology(config, dry_run, with_trace=False, xcon_ng=False, prefix=None):
    """ Start the docker containers for a built topology

        config is the output of a topology build: 'routers' is mandatory,
        'links' and 'hubs' are optional. One container is started per router
        (type 'dummy' is skipped), one vr-xcon container for all p2p links and
        one vr-xcon container per hub.

        dry_run    - only print the docker commands instead of running them
        with_trace - pass --trace to the virtual routers
        xcon_ng    - use /xcon as entrypoint for the p2p vr-xcon container
        prefix     - container name prefix; when None the prefix is taken from
                     the command line arguments (global args)

        Image names are prefixed with $DOCKER_REGISTRY/ if that environment
        variable is set, otherwise with vrnetlab/. Exits the process with
        status 1 if there are no routers or more than one docker network.
    """
    if 'routers' not in config:
        print("No routers in config")
        sys.exit(1)

    if prefix is None:
        prefix = args.prefix

    docker_networks = {r['docker_network'] for r in config['routers'].values() if 'docker_network' in r}
    if len(docker_networks) > 1:
        print("At most 1 docker network allowed")
        sys.exit(1)
    docker_network = next(iter(docker_networks), None)
    if docker_network is not None:
        create_cmd = ["docker", "network", "create", docker_network]
        if dry_run:
            # the printed form is meant for a shell, where "|| true" ignores
            # the failure when the network already exists
            create_cmd.extend(["||", "true"])
        # on a real run no shell is involved, so "||" must not be passed as an
        # argument; a failure is ignored anyway as the exit status is not checked
        run_command(create_cmd, dry_run)

    if os.getenv("DOCKER_REGISTRY"):
        docker_registry = os.getenv("DOCKER_REGISTRY") + "/"
    else:
        docker_registry = 'vrnetlab/'

    for router in sorted(config['routers']):
        val = config['routers'][router]
        if val["type"] == "dummy":
            continue

        name = "%s%s" % (prefix, router)
        cmd = ["docker", "run", "--privileged", "-d",
               "--name", name
        ]
        # flag and value are separate argv entries; there is no shell to split them
        if 'docker_network' in val:
            cmd.extend(['--network', val['docker_network']])
            cmd.extend(['--network-alias', router])
        if 'ip' in val:
            cmd.extend(['--ip', str(val['ip'])])

        cmd.append("%svr-%s:%s" % (docker_registry, val["type"], val["version"]))
        if with_trace:
            cmd.append('--trace')
        if 'run_args' in val:
            if isinstance(val['run_args'], str):
                cmd.extend(val["run_args"].split())
            # pass the keys and values unquoted, leave it up to the user to quote/escape or add dashes
            elif isinstance(val['run_args'], list):
                cmd.extend([str(v) for v in val['run_args']])
            elif isinstance(val['run_args'], dict):
                for k, v in val['run_args'].items():
                    cmd.extend([str(k), str(v)])

        _start_unless_running(name, cmd, dry_run)

    # without a shared docker network the xcon containers reach the routers
    # through legacy container links
    router_links = []
    for vr in sorted(config['routers']):
        router_links.extend(["--link", "%s%s:%s%s" % (prefix, vr, prefix, vr)])

    if 'links' in config:
        name = "%svr-xcon" % prefix
        cmd = ["docker", "run", "--rm", "--privileged", "-d", "--name", name]
        if xcon_ng:
            cmd.extend(["--entrypoint", "/xcon"])

        if docker_network:
            cmd.extend(["--network", docker_network])
        else:
            cmd.extend(router_links)
        cmd.append(docker_registry + "vr-xcon")
        cmd.append("--p2p")
        cmd.extend(["%s%s/%s--%s%s/%s" % (prefix, link["left"]["router"],
                                           link["left"]["numeric"],
                                           prefix,
                                           link["right"]["router"],
                                           link["right"]["numeric"]) for link in config['links']])
        _start_unless_running(name, cmd, dry_run)

    if 'hubs' in config:
        for hub, eps in config['hubs'].items():
            name = "{}vr-xcon-hub-{}".format(prefix, hub)
            cmd = ["docker", "run", "--privileged", "-d", "--name", name]

            if docker_network:
                cmd.extend(["--network", docker_network])
            else:
                cmd.extend(router_links)
            cmd.append(docker_registry + "vr-xcon")
            cmd.append("--hub")
            cmd.extend(["%s%s/%s" % (prefix, ep["router"], ep["numeric"]) for ep in eps])
            run_command(cmd, dry_run)


if __name__ == '__main__':
    # Command line entry point. --build, --run and --template each take a JSON
    # file and are processed in that order when combined.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--build", help="Build topology from config")
    parser.add_argument("--keep-order", action="store_true", help="Keep topology order (routers, links, fullmeshes). Applies to --build only")
    parser.add_argument("--run", help="Run topology")
    parser.add_argument("--dry-run", action="store_true", default=False, help="Only print what would be performed during --run")
    parser.add_argument("--prefix", default='', help="docker container name prefix")
    parser.add_argument("--template", nargs=2, help="produce output based on topology information and a template")
    parser.add_argument("--variable", action='append', help="store variables")
    parser.add_argument("--with-trace", action='store_true', help="run virtual routers with --trace")
    parser.add_argument("--xcon-ng", action='store_true', help="use xcon-ng (acton)")
    # args stays a module-level name: run_topology() reads args.prefix from it
    args = parser.parse_args()

    if args.dry_run and not args.run:
        print("ERROR: --dry-run is only relevant with --run")
        sys.exit(1)

    if args.prefix and not args.run:
        print("ERROR: --prefix is only relevant with --run")
        sys.exit(1)

    if args.build:
        # object_pairs_hook keeps the key order of the file, which --keep-order relies on
        with open(args.build, "r") as input_file:
            config = json.load(input_file, object_pairs_hook=OrderedDict)
        try:
            vt = VrTopo(config, args.keep_order)
        except Exception as exc:
            print("ERROR:", exc)
            sys.exit(1)
        print(vt.output())

    if args.run:
        with open(args.run, "r") as input_file:
            config = json.load(input_file, object_pairs_hook=OrderedDict)
        if args.dry_run:
            print("The following commands would be executed:")
        run_topology(config, args.dry_run, args.with_trace, args.xcon_ng)

    if args.template:
        # --template <topology.json> <template file>
        with open(args.template[0], "r") as input_file:
            config = json.load(input_file, object_pairs_hook=OrderedDict)

        # --variable KEY=VALUE; split on the first '=' only so that the value
        # itself may contain '='
        vs = {}
        for var in args.variable or []:
            key, sep, value = var.partition("=")
            if not sep:
                print("ERROR: --variable expects KEY=VALUE, got %r" % var)
                sys.exit(1)
            vs[key] = value

        env = jinja2.Environment(loader=jinja2.FileSystemLoader(['./']))
        template = env.get_template(args.template[1])
        print(template.render(config=config, vars=vs))

